from tqdm import tqdm
from agent import Agents
from common.replay_buffer import ReplayBuffer
import torch
import os
import numpy as np
import matplotlib.pyplot as plt


class Runner:
    """Drive training and periodic evaluation of ``Agents`` in a vectorized env.

    Each training iteration plays one episode, converts it with the replay
    buffer's ``convert`` and performs one ``agents.learn`` update. Every
    ``args.evaluate_rate`` iterations, episodes are played with epsilon 0.
    The averaged loss/return curves are then plotted and saved under
    ``args.save_dir``.
    """

    def __init__(self, args, env):
        """Store configuration, build agents and buffer, create the output dir.

        Args:
            args: Configuration namespace. Read here: ``eps``,
                ``max_episode_len``, ``buffer_size``, ``save_dir``. The other
                methods also read ``time_steps``, ``evaluate_rate``,
                ``evaluate_episodes`` and ``vec_env``.
            env: Vectorized environment exposing ``reset()`` and
                ``step(actions)`` (see ``run_episode`` for the expected returns).
        """
        self.args = args
        self.eps = args.eps  # exploration rate used for training episodes
        self.episode_limit = args.max_episode_len
        self.env = env
        self.agents = Agents(self.args)
        self.buffer = ReplayBuffer(self.args.buffer_size)
        self.save_path = self.args.save_dir
        # exist_ok avoids the race of a separate os.path.exists() check.
        os.makedirs(self.save_path, exist_ok=True)

    def run(self):
        """Train for ``args.time_steps`` iterations, evaluating periodically.

        At every ``time_step > 0`` that is a multiple of ``args.evaluate_rate``:
        - record the mean training loss since the previous evaluation;
        - record the mean return over ``args.evaluate_episodes`` episodes
          played with epsilon 0;
        - write ``plt.png``, ``returns.npy`` and ``loss.npy`` into
          ``self.save_path``.
        """
        losses, returns = [], []
        loss = 0
        n_updates = 0  # learn() calls accumulated into ``loss`` since last log
        for time_step in tqdm(range(self.args.time_steps)):
            episode, _ = self.run_episode(evaluate=False)
            sample = self.buffer.convert(episode)
            step_loss = self.agents.learn(sample)
            loss += np.sum(step_loss.detach().cpu().numpy())
            n_updates += 1

            if time_step > 0 and time_step % self.args.evaluate_rate == 0:
                reward = 0
                # Evaluation only selects actions; no autograd graph is needed.
                with torch.no_grad():
                    for _ in range(self.args.evaluate_episodes):
                        reward += self.run_episode(evaluate=True)[1]
                returns.append(reward / self.args.evaluate_episodes)
                # Divide by the real update count. The first window covers
                # time steps 0..evaluate_rate, i.e. evaluate_rate + 1 updates,
                # so dividing by evaluate_rate would bias the first point.
                losses.append(loss / n_updates)
                loss = 0
                n_updates = 0

                fig, axs = plt.subplots(2, 1, figsize=(10, 8), sharex=True)

                axs[0].plot(losses)
                axs[0].set_title('Loss')

                # Plot the returns subplot.
                axs[1].plot(returns)
                axs[1].set_title('Returns')

                axs[1].set_xlabel('Epoch({} updates)'.format(self.args.evaluate_rate))
                fig.savefig(os.path.join(self.save_path, 'plt.png'), format='png')

                # Close exactly the figure created above so repeated
                # evaluations do not accumulate open figures.
                plt.close(fig)
                # np.save appends the '.npy' extension itself.
                np.save(os.path.join(self.save_path, 'returns'), returns)
                np.save(os.path.join(self.save_path, 'loss'), losses)

    def run_episode(self, evaluate=False):
        """Roll out one fixed-length episode in every sub-environment.

        Exactly ``self.episode_limit`` env steps are taken. The ``done`` flags
        returned by the env are recorded but do NOT end the rollout early.

        Args:
            evaluate: If True, select actions with epsilon 0; otherwise use
                ``self.eps``.

        Returns:
            ``(records, reward)``:
            - ``records`` is a list with one dict per sub-environment
              (``args.vec_env`` of them). Each dict holds per-step lists under
              ``'o'`` (state ``s[env_id]`` from the env), ``'u'`` (action
              taken), ``'a_u'`` (available actions at that state), ``'r'``
              (column 0 of the env reward) and ``'done'``.
            - ``reward`` is column 0 of the reward, summed over steps and
              averaged over sub-environments.
        """
        eps = 0 if evaluate else self.eps
        n_envs = self.args.vec_env

        s, available_actions = self.env.reset()
        records = [
            {'o': [], 'u': [], 'a_u': [], 'r': [], 'done': []}
            for _ in range(n_envs)
        ]
        reward = 0

        # Actions are chosen one step ahead. The action for the next step is
        # selected right after stepping, and the recurrent hidden state is
        # threaded through consecutive select_actions calls. The action picked
        # after the final step is discarded.
        hidden = self.agents.init_hidden()
        actions, hidden = self.agents.select_actions(s, available_actions, hidden, eps)

        step = 0
        while step < self.episode_limit:
            s_next, r, done, available_actions_next = self.env.step(actions)
            actions_next, hidden = self.agents.select_actions(
                s_next, available_actions_next, hidden, eps)

            # r is indexed as r[env_id, 0]; only its first column is used.
            reward += sum(r[:, 0]) / n_envs

            for env_id, record in enumerate(records):
                record['o'].append(s[env_id])
                record['u'].append(actions[env_id])
                record['a_u'].append(available_actions[env_id])
                record['r'].append(r[env_id, 0])
                record['done'].append(done[env_id])

            s = s_next
            actions = actions_next
            available_actions = available_actions_next
            step += 1

        return records, reward

